#include <algorithm>
#include <cmath>
#include <iostream>
#include <string>
#include <opencv2/opencv.hpp>

using namespace std;

void generateRotation(const float Rx, const float Ry, const float Rz, cv::Mat &Rotx, cv::Mat &Roty, cv::Mat &Rotz)
{
	Rotx = cv::Mat::eye(3, 3, CV_32FC1);
	Roty = cv::Mat::eye(3, 3, CV_32FC1);
	Rotz = cv::Mat::eye(3, 3, CV_32FC1);

	Rotx.at<float>(1, 1) = cos(Rx);
	Rotx.at<float>(1, 2) = -sin(Rx);
	Rotx.at<float>(2, 1) = sin(Rx);
	Rotx.at<float>(2, 2) = cos(Rx);

	Roty.at<float>(0, 0) = cos(Ry);
	Roty.at<float>(0, 2) = sin(Ry);
	Roty.at<float>(2, 0) = -sin(Ry);
	Roty.at<float>(2, 2) = cos(Ry);

	Rotz.at<float>(0, 0) = cos(Rz);
	Rotz.at<float>(0, 1) = -sin(Rz);
	Rotz.at<float>(1, 0) = sin(Rz);
	Rotz.at<float>(1, 1) = cos(Rz);
}


/**
 * Renders an equirectangular panorama into a fisheye disk drawn on sphereImg.
 *
 * Every destination pixel within radius = min(width, height) / 2 of the image centre is
 * mapped with an equidistant fisheye model (the polar angle grows linearly from 0 degrees
 * at the centre to 90 degrees at the rim) to a unit view direction. That direction is
 * rotated by Rot, converted to polar angle / longitude, and bilinearly sampled from panoImg.
 * The panorama's rows span the polar angle 0..pi and its columns span a full 2*pi of
 * longitude. Pixels outside the disk are left untouched.
 *
 * @param panoImg   Source panorama; must be non-empty with 8-bit depth.
 * @param sphereImg Pre-allocated destination; must have the same type as panoImg.
 * @param center    Not used by the current projection; kept for interface compatibility.
 * @param Rot       3x3 CV_32FC1 matrix applied to every view direction before sampling.
 * @throws cv::Exception (via CV_Assert) when the inputs violate the requirements above.
 */
void pano2Sphere(const cv::Mat panoImg, cv::Mat &sphereImg, const cv::Point3f &center, const cv::Mat Rot)
{
	(void)center;

	CV_Assert(!panoImg.empty() && panoImg.depth() == CV_8U);
	CV_Assert(!sphereImg.empty() && sphereImg.type() == panoImg.type());
	CV_Assert(Rot.rows == 3 && Rot.cols == 3 && Rot.type() == CV_32FC1);

	const int panoWidth = panoImg.cols;
	const int panoHeight = panoImg.rows;
	const int spWidth = sphereImg.cols;
	const int spHeight = sphereImg.rows;
	const int channel = panoImg.channels();

	const int cx = spWidth / 2;
	const int cy = spHeight / 2;
	const int radius = std::min(cx, cy);
	if (radius <= 0)
		return; // a 1-pixel-wide/high target has no disk to draw (and would divide by zero)

	// Copy the rotation once; at<> respects the matrix step even for non-continuous Mats.
	float R[9];
	for (int k = 0; k < 9; k++)
		R[k] = Rot.at<float>(k / 3, k % 3);

	const double twoPi = 2.0 * CV_PI;

	for (int i = 0; i < spHeight; i++)
	{
		uchar *dstRow = sphereImg.ptr<uchar>(i);
		for (int j = 0; j < spWidth; j++)
		{
			// Axis convention: x grows with the row index, y shrinks with the column index.
			// Each offset is taken from the centre of its own image axis so non-square
			// targets stay centred.
			const float px = static_cast<float>(i - cy);
			const float py = static_cast<float>(cx - j);

			const float r = std::sqrt(px * px + py * py);
			const float thetaDeg = r * 90.0f / radius;
			if (thetaDeg > 90.0f)
				continue;

			// Unit-length view direction; the disk centre (r == 0) looks along +z.
			const float thetaRad = static_cast<float>(thetaDeg * CV_PI / 180.0);
			const float sinTheta = std::sin(thetaRad);
			const float x0 = (r > 0.0f) ? sinTheta * px / r : 0.0f;
			const float y0 = (r > 0.0f) ? sinTheta * py / r : 0.0f;
			const float z0 = std::cos(thetaRad);

			const float x1 = x0 * R[0] + y0 * R[1] + z0 * R[2];
			const float y1 = x0 * R[3] + y0 * R[4] + z0 * R[5];
			const float z1 = x0 * R[6] + y0 * R[7] + z0 * R[8];

			// Rounding can push |z1| just past 1; acos would then return NaN.
			const float theta2 = std::acos(std::max(-1.0f, std::min(1.0f, z1)));
			double fi2 = std::atan2(y1, x1);
			if (fi2 < 0)
				fi2 += twoPi;

			const double v = panoHeight * theta2 / CV_PI;
			const double u = (twoPi - fi2) * panoWidth / twoPi;

			// Bilinear interpolation; fractions come from the unclamped sample position.
			const int u0 = static_cast<int>(std::floor(u));
			const int v0 = static_cast<int>(std::floor(v));
			const float dx = static_cast<float>(u - u0);
			const float dy = static_cast<float>(v - v0);

			// Longitude is periodic, so columns wrap around the seam; rows clamp at the poles.
			const int srcU0 = ((u0 % panoWidth) + panoWidth) % panoWidth;
			const int srcU1 = (srcU0 + 1) % panoWidth;
			const int srcV0 = std::max(0, std::min(panoHeight - 1, v0));
			const int srcV1 = std::max(0, std::min(panoHeight - 1, v0 + 1));

			const float w0 = (1 - dx) * (1 - dy);
			const float w1 = dx * (1 - dy);
			const float w2 = (1 - dx) * dy;
			const float w3 = dx * dy;

			// Row pointers honour the Mat step, so padded / ROI images are read correctly.
			const uchar *row0 = panoImg.ptr<uchar>(srcV0);
			const uchar *row1 = panoImg.ptr<uchar>(srcV1);
			const uchar *p00 = row0 + srcU0 * channel;
			const uchar *p01 = row0 + srcU1 * channel;
			const uchar *p10 = row1 + srcU0 * channel;
			const uchar *p11 = row1 + srcU1 * channel;
			uchar *dst = dstRow + j * channel;

			for (int chan = 0; chan < channel; chan++)
			{
				const float value = w0 * p00[chan] + w1 * p01[chan] + w2 * p10[chan] + w3 * p11[chan];
				// saturate_cast rounds and clamps; a plain float->uchar conversion is UB above 255.
				dst[chan] = cv::saturate_cast<uchar>(value);
			}
		}
	}
}

// Demo entry point: reads data/pano.jpg, renders it as a fisheye disk (identity rotation
// by default) and writes the result to data/sphere.png.
// Returns 0 on success and 1 if the input cannot be read or the output cannot be written.
int main()
{
	const std::string imgPath = "data/pano.jpg";
	cv::Mat panoImg = cv::imread(imgPath);
	if (panoImg.empty())
	{
		std::cerr << "Failed to read panorama: " << imgPath << '\n';
		return 1;
	}

	// Same pixel type as the panorama so channels can be copied one-to-one.
	cv::Mat sphereImg = cv::Mat::zeros(1000, 1000, panoImg.type());

	float angleX = 0;
	float angleY = 0; //90 * M_PI / 180;
	float angleZ = 0; // 120 * M_PI / 180;
	cv::Mat Rotx, Roty, Rotz, Rot;
	generateRotation(angleX, angleY, angleZ, Rotx, Roty, Rotz);
	Rot = Rotx * Rotz * Roty;

	cv::Point3f projCenter(0, 0, 1);
	pano2Sphere(panoImg, sphereImg, projCenter, Rot);

	const std::string outPath = "data/sphere.png";
	if (!cv::imwrite(outPath, sphereImg))
	{
		std::cerr << "Failed to write output: " << outPath << '\n';
		return 1;
	}
	return 0;
}